import numpy as np
import scipy as sp
import torch
import torch.nn as nn
import torch.optim as optim

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(dev)

class Network(nn.Module):
    """Single bias-free linear layer mapping a `dim`-vector to one scalar.

    `hidden_size` is accepted for signature compatibility but is not used:
    the model has no hidden layer.
    """

    def __init__(self, dim, hidden_size=100):
        super().__init__()
        # One output unit, no bias term: the model is just `w . x`.
        self.fc1 = nn.Linear(in_features=dim, out_features=1, bias=False)

    def forward(self, x):
        """Apply the linear layer to `x` and return the result."""
        prediction = self.fc1(x)
        return prediction

class ConservativeLinearUCB:
    """Linear contextual bandit with a conservative fallback to a baseline arm.

    Each round, `select` scores every arm with
    ``f(x) - nu * sqrt(x^T U^-1 x)`` and takes the arm with the smallest
    score. That arm is only played when a bound on the cumulative outcome
    stays within ``(1 + alpha)`` times the baseline's accumulated outcome;
    otherwise the baseline arm is played and the string ``'baseline'`` is
    returned.

    NOTE(review): arg-min selection combined with a ``<=`` budget test
    suggests the "reward" values are really costs to be minimised -- confirm
    with the caller.
    """

    def __init__(self, dim, lamdba=1, nu=1, hidden=100, alpha=0.1):
        """Create the learner.

        Args:
            dim: Length of a context vector.
            lamdba: Scale of the initial design matrix and the SGD weight decay.
            nu: Multiplier of the per-arm confidence width.
            hidden: Forwarded to `Network` as `hidden_size`.
            alpha: Slack allowed relative to the baseline in the safety test.
        """
        self.func = Network(dim, hidden_size=hidden).to(device)
        # Training set accumulated by `update`.
        self.context_list = []
        self.reward = []
        self.lamdba = lamdba
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        # Design matrix; `select` multiplies it with context vectors, so this
        # relies on the model having exactly `dim` trainable parameters.
        self.U = lamdba * torch.eye(self.total_param).to(device)
        self.nu = nu
        self.alpha = alpha
        self.d = dim
        # Outcome recorded for every round (learner arm or baseline).
        self.all_rewards = []
        # Baseline's true outcome for every round, whether or not it was played.
        self.baseline_rewards = []
        # Running sum of the context vectors of the arms the learner played.
        self.z = torch.zeros(self.d).to(device)
        # Baseline's true outcome for the rounds where the baseline was played.
        self.selected_baseline_rewards = []

    def select(self, context, rewards, true_rewards, baseline_arm):
        """Pick an arm for this round, or fall back to the baseline.

        Args:
            context: NumPy array with one context vector per arm (one row each).
            rewards: Indexable by arm; `rewards[arm]` is recorded when the
                learner's arm is played.
            true_rewards: Indexable by arm; `true_rewards[baseline_arm]` is
                recorded as the baseline's outcome.
            baseline_arm: Index of the baseline arm.

        Returns:
            The selected arm index (as returned by `np.argmin`) when the
            safety test passes, otherwise the string ``'baseline'``.
        """
        model_device = next(self.func.parameters()).device
        tensor = torch.from_numpy(context).float().to(model_device)

        # Nothing here is differentiated, so skip building autograd graphs.
        with torch.no_grad():
            mu = self.func(tensor)
            # U is not modified while scoring, so invert it once per round.
            U_inv = torch.linalg.inv(self.U)

            scores = []
            for x, fx in zip(tensor, mu):
                width = self.nu * torch.sqrt(torch.matmul(x, torch.matmul(U_inv, x)))
                scores.append(fx.item() - width.item())

            arm = np.argmin(scores)
            self.baseline_rewards.append(true_rewards[baseline_arm])

            chosen = tensor[arm]
            # The safety bound is evaluated at the accumulated played contexts
            # plus the candidate arm. The width must use the *selected* arm;
            # the previous code used the stale loop index (the last arm).
            # NOTE(review): unlike the per-arm width, this one is not scaled
            # by `nu` -- kept as in the original, confirm that is intended.
            candidate = self.z + chosen
            ucb_rewards = self.func(candidate) + torch.sqrt(
                torch.matmul(candidate, torch.matmul(U_inv, candidate))
            )
            budget = (1 + self.alpha) * sum(self.baseline_rewards)
            is_safe = bool(ucb_rewards + sum(self.selected_baseline_rewards) <= budget)

        if is_safe:
            self.all_rewards.append(rewards[arm])
            self.U += torch.outer(chosen, chosen)
            self.z += chosen
            return arm

        self.all_rewards.append(true_rewards[baseline_arm])
        self.selected_baseline_rewards.append(true_rewards[baseline_arm])
        return 'baseline'

    def update(self, context, reward):
        """Append one (context, reward) observation to the training set.

        `context` is a NumPy array; it is stored as a float tensor reshaped
        to a single row.
        """
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.reward.append(reward)

    def train(self):
        """Fit the model to all stored observations with per-sample SGD.

        Samples are visited in one random order, repeated epoch after epoch,
        until either 1000 gradient steps have been taken or an epoch's mean
        squared error is at most 1e-3.

        Returns:
            0 when there is no data; the total loss divided by 1000 when the
            step limit is hit; otherwise the mean loss of the final epoch.
        """
        length = len(self.reward)
        if length == 0:
            return 0

        model_device = next(self.func.parameters()).device
        optimizer = optim.SGD(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba)
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx]
                r = self.reward[idx]
                optimizer.zero_grad()
                delta = self.func(c.to(model_device)) - r
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                # Hard cap on the number of gradient steps per call.
                if cnt >= 1000:
                    return tot_loss / 1000
            if batch_loss / length <= 1e-3:
                return batch_loss / length



class NeuralUCBDiag:
    """UCB-style bandit that keeps a diagonal approximation of the design matrix.

    `U` is a vector with one entry per trainable model parameter. `select`
    uses the gradient of each arm's prediction with respect to the parameters
    as that arm's feature vector.
    """

    def __init__(self, dim, lamdba=1, nu=1, hidden=100):
        """Create the learner.

        Args:
            dim: Length of a context vector.
            lamdba: Initial value of every entry of `U`, and the SGD weight decay.
            nu: Multiplier inside the exploration bonus.
            hidden: Forwarded to `Network` as `hidden_size`.
        """
        # Use the module-level `device` (CUDA when available, else CPU) rather
        # than an unconditional `.cuda()`, which fails on CPU-only machines.
        self.func = Network(dim, hidden_size=hidden).to(device)
        # Training set accumulated by `train`.
        self.context_list = []
        self.reward = []
        self.lamdba = lamdba
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        self.U = lamdba * torch.ones((self.total_param,)).to(device)
        self.nu = nu

    def select(self, context):
        """Score every arm, pick one, and fold its gradient into `U`.

        Args:
            context: NumPy array with one context vector per arm (one row each).

        Returns:
            A tuple ``(arm, grad_norm, ave_sigma, ave_rew)``: the selected arm
            index, the norm of that arm's gradient, and the sums over all arms
            of the bonus and of the score (sums, despite the names).
        """
        model_device = next(self.func.parameters()).device
        tensor = torch.from_numpy(context).float().to(model_device)
        mu = self.func(tensor)
        g_list = []
        sampled = []
        ave_sigma = 0
        ave_rew = 0
        for fx in mu:
            # Gradient of this arm's prediction w.r.t. all model parameters.
            self.func.zero_grad()
            fx.backward(retain_graph=True)
            g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
            g_list.append(g)
            sigma2 = self.lamdba * self.nu * g * g / self.U
            sigma = torch.sqrt(torch.sum(sigma2))

            # NOTE(review): the bonus is added but the arm with the *smallest*
            # score is chosen below; ConservativeLinearUCB in this file
            # subtracts the bonus before its arg-min. Confirm which is intended.
            sample_r = fx.item() + sigma.item()

            sampled.append(sample_r)
            ave_sigma += sigma.item()
            ave_rew += sample_r
        arm = np.argmin(sampled)
        self.U += g_list[arm] * g_list[arm]
        return arm, g_list[arm].norm().item(), ave_sigma, ave_rew

    def train(self, context, reward):
        """Store one observation, then fit the model with per-sample SGD.

        Samples are visited in one random order, repeated epoch after epoch,
        until either 1000 gradient steps have been taken or an epoch's mean
        squared error is at most 1e-3.

        Args:
            context: NumPy array for the played arm; stored as a one-row
                float tensor.
            reward: Observed target value for that context.

        Returns:
            The total loss divided by 1000 when the step limit is hit,
            otherwise the mean loss of the final epoch.
        """
        self.context_list.append(torch.from_numpy(context.reshape(1, -1)).float())
        self.reward.append(reward)
        model_device = next(self.func.parameters()).device
        optimizer = optim.SGD(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba)
        length = len(self.reward)
        index = np.arange(length)
        np.random.shuffle(index)
        cnt = 0
        tot_loss = 0
        while True:
            batch_loss = 0
            for idx in index:
                c = self.context_list[idx]
                r = self.reward[idx]
                optimizer.zero_grad()
                delta = self.func(c.to(model_device)) - r
                loss = delta * delta
                loss.backward()
                optimizer.step()
                batch_loss += loss.item()
                tot_loss += loss.item()
                cnt += 1
                # Hard cap on the number of gradient steps per call.
                if cnt >= 1000:
                    return tot_loss / 1000
            if batch_loss / length <= 1e-3:
                return batch_loss / length